# Blank out unwanted values in a numeric vector by setting them to NA.
# Negative values are always removed; optional cutoffs remove outliers too.
#
# vec:         numeric vector to clean
# outlier_top: values strictly greater than this become NA (NULL = no cutoff)
# outlier_bot: values strictly less than this become NA (NULL = no cutoff)
# Returns `vec` with the flagged entries replaced by NA.
neg_to_na <- function(vec, outlier_top = NULL, outlier_bot = NULL){
  to_blank <- vec < 0
  if (!is.null(outlier_top)) to_blank <- to_blank | vec > outlier_top
  if (!is.null(outlier_bot)) to_blank <- to_blank | vec < outlier_bot
  vec[to_blank] <- NA
  vec
}

# Cluster-robust standard errors for a fitted model (e.g. for stargazer's `se`).
#
# model:   fitted model passed to cluster.vcov()
# cluster: name of the clustering column; read from `data` when supplied,
#          otherwise from `model$data`
# keep:    accepted but not used
# data:    optional data frame holding the clustering column
# Returns sqrt(diag(cluster.vcov(model, <cluster column>))).
calc.ses.cluster <- function(model, cluster, keep = NULL, data = NULL){
  source_df <- if (is.null(data)) model$data else data
  cluster_ids <- source_df[, cluster]
  sqrt(diag(cluster.vcov(model, cluster_ids)))
}

################################################################################

# Run a two-sample test for every pair of groups in `group_var`, separately
# for each variable in `vars`.
#
# data:      data.frame or tibble holding the variables and grouping column
# vars:      character vector of column names to test
# group_var: name (string) of the grouping column
# test:      test function. `t.test` is called as test(x_group1, x_group2);
#            any other function is called as test(x, group) restricted to the
#            rows of the two groups (e.g. chisq.test, kruskal.test).
# Returns a nested list: one element per variable, each a list of test
# results named "<g1> vs <g2>".
multiple_tests <- function(data, vars, group_var, test){
  # how `test` was written by the caller decides the calling convention
  fun_name <- deparse(substitute(test))
  is_t_test <- fun_name %in% c("t.test", "stats::t.test")

  # `[[` works identically for data.frames and tibbles (`[, col]` does not)
  groups <- data[[group_var]]
  test_groups <- combn(unique(groups), m = 2, simplify = FALSE)
  names(test_groups) <- vapply(test_groups,
                               function(x) paste(x[1], x[2], sep = " vs "),
                               character(1))

  # `data[vars]` keeps a one-column data frame for a single variable, whereas
  # `data[, vars]` would drop to a vector and lapply() would iterate over its
  # individual values
  lapply(data[vars], function(x){
    lapply(test_groups, function(y) {
      if (is_t_test) {
        test(x[groups %in% y[1]], x[groups %in% y[2]])
      } else {
        in_pair <- groups %in% y
        test(x[in_pair], groups[in_pair])
      }
    })
  })
}

################################################################################

# Two-way frequency table of `x` by `y`, returned as a wide data frame:
# one row per level of `x`, one count column per level of `y`.
cross_tab <- function(x, y) {
  long_counts <- as.data.frame(table(x, y))
  spread(long_counts, key = y, value = Freq)
}

################################################################################

# Balance table: mean of each variable by group, plus Welch t-test p-values
# (stats::t.test defaults) for every pair of groups.
#
# data:      data frame
# group_var: grouping column (bare name, or an existing quosure when
#            passed = TRUE, as done by create_sum_stats_tab())
# ...:       bare names/expressions of the variables to summarise
# passed:    TRUE when `group_var` is already a quosure
# Returns a data frame with one row per variable: `Variable`, one mean column
# per group (named by group value) and one p-value column per group pair
# (named "<g1> vs <g2>"). A factor group is converted to character first, so
# groups are ordered by sorted value rather than by factor level.
balance_table <- function(data, group_var, ..., passed = FALSE){
  
  if(!passed){
    group_var <- enquo(group_var)
  }
  vars <- quos(...)
  
  group_var_name <- rlang::as_name(group_var)
  
  # if group_var is factor, convert to character
  if(is.factor(data[[group_var_name]])){
    data[[group_var_name]] <- as.character(data[[group_var_name]])
  }
  
  # sorted distinct groups, matching the order group_by() produces below
  uniq_vals <- data |>  
    arrange(!!group_var) |>  
    pull(!!group_var) |>
    unique()
  test_groups <- combn(uniq_vals, m = 2, simplify = FALSE)
  
  # give test groups meaningful names
  names(test_groups) <- vapply(test_groups,
                               function(x) paste(x[1], x[2], sep = " vs "),
                               character(1))
  
  # the grouping vector is the same for every variable: extract it once
  group_vec <- data %>% pull(!!group_var)
  
  map_dfr(vars, function(x){
    
    group_mean <- data %>% 
      group_by(!!group_var) %>%
      summarise(mean = mean(!!x, na.rm = TRUE))
    
    vec <- data %>% pull(!!x)
    
    p_vals <- map_dbl(test_groups, function(y){
      test_vec <- t.test(vec[group_vec == y[[1]]],
                         vec[group_vec == y[[2]]]) %>% broom::tidy()
      test_vec[["p.value"]]
    })
    
    # build the row directly from the numeric values; the former round trip
    # through as.matrix() formatted each number as a string (7 significant
    # digits) before converting back, silently losing precision.
    # deparse() of a quosure yields "~expr", hence the "~" removal.
    balance_vec <- data.frame(
      str_remove(paste(deparse(x), collapse = ""), "~"),
      matrix(group_mean$mean, ncol = nrow(group_mean)),
      matrix(p_vals, ncol = length(p_vals))
    )
    names(balance_vec) <- c("Variable", 
                            group_mean[[1]],
                            names(test_groups))
    balance_vec
  })
}

################################################################################

# Balance table with cluster-adjusted tests: mean of each variable by group,
# plus, for every pair of groups, a p-value from Hmisc::t.test.cluster()
# using `cluster_var` as the cluster identifier.
#
# data:        data frame
# group_var:   grouping column (bare name, or a quosure when passed = TRUE)
# cluster_var: clustering column (bare name, or a quosure when passed = TRUE)
# ...:         bare names/expressions of the variables to summarise
# passed:      TRUE when group_var / cluster_var are already quosures, as done
#              by create_sum_stats_tab()
# Returns a data frame with one row per variable: `Variable`, one mean column
# per group (named by group value) and one p-value column per group pair
# (named "<g1> vs <g2>"). A factor group is converted to character first.
balance_table_cluster <- function(data, group_var, cluster_var, ..., passed = FALSE){
  
  if(!passed){
    group_var <- enquo(group_var)
    cluster_var <- enquo(cluster_var)
  }
  vars <- quos(...)
  
  group_var_name <- rlang::as_name(group_var)
  
  # if group_var is factor, convert to character
  if(is.factor(data[[group_var_name]])){
    data[[group_var_name]] <- as.character(data[[group_var_name]])
  }
  
  # sorted distinct groups, matching the order group_by() produces below
  uniq_vals <- data |>  
    arrange(!!group_var) |>  
    pull(!!group_var) |>
    unique()
  test_groups <- combn(uniq_vals, m = 2, simplify = FALSE)
  
  # give test groups meaningful names
  names(test_groups) <- vapply(test_groups,
                               function(x) paste(x[1], x[2], sep = " vs "),
                               character(1))
  cluster_vec <- data %>% pull(!!cluster_var)
  group_vec <- data %>% pull(!!group_var)
  
  map_dfr(vars, function(x){
    
    group_mean <- data %>% 
      group_by(!!group_var) %>%
      summarise(mean = mean(!!x, na.rm = TRUE))
    
    var_vec <- data %>% pull(!!x)
    
    p_vals <- map_dbl(test_groups, function(y){
      # rows belonging to either group of the pair; %in% leaves out rows whose
      # group is NA instead of turning them into NA subscripts
      index <- group_vec %in% c(y[[1]], y[[2]])
      
      test_vec <- Hmisc::t.test.cluster(var_vec[index], cluster_vec[index],
                                        group_vec[index])
      
      # NOTE(review): the p-value is read by position (row 20, column 1) of
      # the matrix returned by Hmisc::t.test.cluster(); confirm this still
      # holds for the installed Hmisc version.
      test_vec[20, 1]
    })
    
    # build the row directly from the numeric values; the former round trip
    # through as.matrix() formatted each number as a string (7 significant
    # digits) before converting back, silently losing precision.
    # deparse() of a quosure yields "~expr", hence the "~" removal.
    balance_vec <- data.frame(
      str_remove(paste(deparse(x), collapse = ""), "~"),
      matrix(group_mean$mean, ncol = nrow(group_mean)),
      matrix(p_vals, ncol = length(p_vals))
    )
    names(balance_vec) <- c("Variable", 
                            group_mean[[1]],
                            names(test_groups))
    balance_vec
  })
}

################################################################################

# Plot the mean of `y` over time for each group, optionally with t-based
# confidence intervals as error bars.
#
# data:        data frame
# y:           outcome column (bare name)
# time_var:    time column (bare name); default `year_month`
# group_var:   grouping column (bare name); default `treatment_status`
# conf:        significance level; intervals cover 1 - conf (default 95%)
# conf_int:    draw the confidence intervals?
# point_shape: TRUE adds points with one shape per group; FALSE instead
#              overlays a second line of width `point_size`
# point_size:  point size (or line width when point_shape = FALSE)
# Returns a ggplot object. Lower interval bounds are truncated at 0.
plot_revenue_tgs <- function(data, y, time_var = year_month, 
                             group_var = treatment_status,
                             conf = .05, conf_int = TRUE,
                             point_shape = TRUE, point_size = 1){
  
  time_var <- enquo(time_var)
  group_var <- enquo(group_var)
  y <- enquo(y)
  
  plot_dat <- data %>%
    group_by(!!group_var, !!time_var) %>% 
    summarize(estimate = mean(!!y, na.rm = TRUE),
              sd = sd(!!y, na.rm = TRUE),
              n = sum(!is.na(!!y)),
              .groups = "drop") %>% 
    filter(n > 0) %>% 
    mutate(UB = estimate + qt(1 - (conf/2), n - 1) * (sd/sqrt(n)),
           LB = estimate - qt(1 - (conf/2), n - 1) * (sd/sqrt(n)),
           LB = ifelse(LB < 0, 0, LB)) %>% 
    arrange(!!time_var)
  
  rev_plot <- ggplot(data = plot_dat, 
                     aes(x = !!time_var, y = estimate, color = !!group_var, 
                         group = !!group_var)) + geom_line() + theme_bw() +
    theme(axis.text.x = element_text(angle = 270, size = 7.5, vjust = .5)) 
  
  if(point_shape){
    # `mapping` spelled out (the old `map =` relied on partial matching)
    rev_plot <- rev_plot + geom_point(mapping = aes(shape = !!group_var),
                                      size = point_size)
  } else {
    # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4)
    rev_plot <- rev_plot + geom_line(linewidth = point_size)
  }
  
  if(conf_int){
    rev_plot <- rev_plot + 
      geom_errorbar(aes(ymin = LB, ymax = UB))
  }  
  
  rev_plot
}


# Plot `y` over time for the rows of `data` selected by `index` (e.g. one
# market). Rows are sorted by `time_var` and drawn at evenly spaced positions
# 1..n, with the time values used as x-axis labels.
#
# data:     data frame
# index:    row selector applied as data[index, ]
# y:        outcome column (bare name)
# time_var: time column (bare name); default `year_month`
# Returns a ggplot object.
plot_revenue <- function(data, index, y, time_var = year_month){
  
  time_var <- enquo(time_var)
  y <- enquo(y)
  
  data <- data[index, ] %>% arrange(!!time_var)
  time_labels <- data %>% pull(!!time_var)
  
  # seq_along() rather than 1:length(): an empty selection must give an empty
  # sequence, not c(1, 0)
  data$time <- seq_along(time_labels)
  
  ggplot(data = data, 
         aes(x = time, y = !!y, group = 1)) +
    scale_x_continuous(breaks = seq_along(time_labels),
                       labels = time_labels) + 
    geom_point(size = 1) + geom_line() + theme_bw() +
    theme(axis.text.x = element_text(angle = 270, size = 7.5, vjust = .5))
}

# Print one revenue plot (via plot_revenue()) per market, sorted by District
# and then market name, titled "<market> - <District>\n(<treatment_status>)".
# Mulanje markets get a red dot-dash vertical line at x = 7 (Wendewende) or
# x = 4 (all others).
#
# market_month: market-by-month data with columns official_name,
#               treatment_status and District. Defaults to an object called
#               `market_month` looked up from the caller's environment (the old
#               self-referential default could never be evaluated).
# y:            outcome column (bare name)
# y_name:       y-axis label
# y_max:        upper y limit; only used when y_lim = TRUE
# y_lim:        TRUE zooms each plot to [0, y_max]; FALSE starts the y axis at
#               0 and lets ggplot choose the top
# Prints every plot and invisibly returns them as a list named by market.
plot_revenue_all <- function(market_month = get("market_month",
                                                 envir = parent.frame()), 
                             y, y_name, y_max, y_lim = TRUE){
  y <- enquo(y)
  
  market_treat <- unique(market_month %>% 
                           select(official_name, treatment_status, District)) %>%
    arrange(District, official_name) %>% 
    as.data.frame()
  
  plot_list <- vector("list", nrow(market_treat))
  names(plot_list) <- market_treat$official_name
  
  for(i in seq_len(nrow(market_treat))){
    market <- market_treat$official_name[i]
    rev_plot <- plot_revenue(market_month, market_month$official_name == market,
                             y = !!y) +
      labs(y = y_name, x = "Year and Month",
           title = paste0(market, " - ", market_treat$District[i],
                          "\n(", market_treat$treatment_status[i], ")"))
    
    if(y_lim){
      rev_plot <- rev_plot + coord_cartesian(ylim = c(0, y_max))
    } else {
      rev_plot <- rev_plot + ylim(c(0, NA))
    }
    
    if(market_treat$District[i] == "Mulanje"){
      vline_x <- if (market == "Wendewende") 7 else 4
      rev_plot <- rev_plot + geom_vline(xintercept = vline_x, linetype = "dotdash", 
                                        color = "red")
    }
    
    plot_list[[i]] <- rev_plot
    print(rev_plot)
  }
  
  invisible(plot_list)
}

################################################################################

# Number of non-missing entries in `x` (NaN counts as missing).
non_nas <- function(x) {
  length(which(!is.na(x)))
}

# For each variable in `...`, count its non-missing values within each level
# of `group`, then summarise those per-group counts across groups.
#
# data:  data frame
# group: grouping column (bare name)
# ...:   variables to count (passed to dplyr::vars())
# Returns a list with
#   n_sd_grouped: one row per variable with the mean (`ave_n_resp`) and
#                 standard deviation (`sd_n_resp`) of its per-group counts
#   n_group:      the per-group counts (grouping column first)
get_n_sd_grouped <- function(data, group,...){
  
  group <- enquo(group)
  
  # `.funs` spelled out; the old `.fun` only worked via partial matching
  n_group <- data %>% 
    group_by(!!group) %>%
    summarise_at(vars(...), .funs = non_nas) 
  
  # every column after the grouping column holds one variable's counts;
  # `[-1]` (not `[2:ncol()]`) stays correct when there are no such columns
  count_cols <- n_group[-1]
  variable_names <- names(count_cols)
  
  n_sd_grouped <- count_cols %>% 
    map_dfr(.f = function(x) {
      data.frame(ave_n_resp = mean(x),
                 sd_n_resp = sd(x))
    })
  
  n_sd_grouped <- n_sd_grouped %>%
    mutate(Variable = variable_names) %>% 
    select(Variable, everything())
  
  list(n_sd_grouped = n_sd_grouped,
       n_group = n_group)
} 

################################################################################

# Transpose a data frame whose first column holds row labels. The remaining
# columns become rows (their names go in a new "Column" column) and the
# first-column values become the new column names. Returns a tibble.
# Adapted from https://stackoverflow.com/questions/42790219/how-do-i-transpose-a-tibble-in-r

transpose_df <- function(df) {
  value_cols <- df[2:ncol(df)]
  flipped <- data.table::transpose(value_cols)
  rownames(flipped) <- colnames(value_cols)
  flipped <- tibble::as_tibble(tibble::rownames_to_column(flipped))
  names(flipped) <- c("Column", df[[1]])
  flipped
}

################################################################################

# Pairwise comparisons between all levels of a treatment variable via lm().
#
# For each level of `treat_var` in turn, the factor is releveled so that level
# is the baseline, `formula` is refit with lm(), and the coefficients whose
# term name contains the treatment variable's name are kept. Each unordered
# pair of levels is then kept once and the p-values are adjusted for multiple
# testing.
#
# formula:         model formula for lm(); presumably includes `treat_var`
#                  as a regressor (TODO confirm with callers)
# treat_var:       treatment column (bare name); converted to a factor if it
#                  is not one already
# data:            data frame used for fitting
# p.adjust.method: one of stats::p.adjust.methods (match.arg() picks the first,
#                  "holm", by default)
# cluster:         if TRUE, replace the lm p-values with coeftest() p-values
#                  using cluster.vcov() on `data[cluster_vars]`
# cluster_vars:    name(s) of the clustering column(s)
#
# Returns a data frame with columns term (the compared level), estimate,
# p.value, base_lev (the baseline level), combined ("<a> vs <b>" with the two
# levels sorted) and p_<method> (the adjusted p-value).
pairwise.reg.test <- function(
    formula,
    treat_var,
    data,
    p.adjust.method = p.adjust.methods,
    cluster = FALSE,
    cluster_vars = NULL) {
  
  p.adjust.method <- match.arg(p.adjust.method)
  
  is_fct <- data |> pull({{ treat_var }}) |> is.factor()
  
  if(!is_fct){
    data <- data |> 
      mutate("{{treat_var}}" := as.factor({{ treat_var }}))
  }
  
  # levels of the treatment; each is used in turn as the baseline
  # (note: this local `ls` masks base::ls inside the function)
  ls <- data |> 
    pull({{ treat_var }}) |> 
    levels()
  
  # refit the model with `lev` as baseline and return the treatment contrasts
  get_reg_pvals <- function(lev) {
    
    data <- data |> 
      mutate("{{treat_var}}" := fct_relevel({{ treat_var }}, lev))
    
    mod <- lm(formula, data) 
    
    results <- mod |> 
      tidy() 
    
    #adjust for clustering
    # NOTE(review): assumes coeftest() returns one row per row of tidy(mod),
    # in the same order; aliased (NA) coefficients would misalign this.
    if(cluster) {
      results$p.value <- coeftest(mod, 
                                  vcov = function(x) cluster.vcov(x, data[cluster_vars])) |> 
        tidy() |> 
        pull(p.value)
    }
    
    # keep terms whose name contains the treatment variable's name (regex
    # match, so e.g. interaction terms containing it are kept as well), then
    # strip that name so `term` holds the compared level
    results <- results |> 
      filter(str_detect(term, as.character(enexpr(treat_var)))) |> 
      select(term, estimate, p.value) |> 
      mutate(base_lev = lev,
             term = str_remove(term, as.character(enexpr(treat_var))))
    
    results  
  }
  
  # get naive p_values
  pvals <- map(ls, get_reg_pvals) |> 
    list_rbind()
  
  # concatenate, then order, then combine
  # (sorting the two levels makes "a vs b" and "b vs a" identical, so the
  # duplicate comparison from the other baseline can be dropped below)
  pvals <- pvals |>
    mutate(combined = map2(base_lev, term, c),
           ord = map(combined, order),
           combined = map2(combined, ord, \(.x, .y) .x[.y]),
           combined = map_chr(combined, \(x) paste0(x, collapse = " vs "))
    )
  # drop ord column
  pvals$ord <- NULL
  
  # remove duplicated comparisons (first occurrence is kept)
  pvals <- pvals |> 
    filter(!duplicated(combined))
  
  #p adjust
  pnam <- paste0("p_", p.adjust.method)
  pvals <- pvals |>
    mutate({{pnam}} := p.adjust(p.value, p.adjust.method))
  
  pvals  
}

################################################################################
# coefplot from dataframe
# Tidy a list of models with cluster-robust inference.
#
# mods_list:   list of fitted models
# names:       optional names for the returned list (NULL keeps the existing
#              names of `mods_list`)
# cluster_ids: cluster identifiers passed to cluster.vcov()
# Returns a list of tidy() tables (with confidence intervals) computed from
# coeftest() with the clustered covariance matrix, one per model.
make_tidys_list <- function(mods_list, names = NULL, 
                            cluster_ids = vendor_end$market) {
  
  tidy_clustered <- function(mod) {
    clustered_test <- coeftest(mod,
                               vcov = function(m) cluster.vcov(m, cluster_ids))
    tidy(clustered_test, conf.int = TRUE)
  }
  
  tidy_list <- lapply(mods_list, tidy_clustered)
  
  if (!is.null(names)) names(tidy_list) <- names
  
  tidy_list
}

# Stack a list of tidy coefficient tables into one data frame, adding a
# `model` column with each table's list name (or its position if unnamed).
combine_tidy_dfs <- function(tidy_dfs) {
  
  labelled <- imap(tidy_dfs, function(tidy_df, model_name) {
    tidy_df$model <- model_name
    tidy_df
  })
  
  list_rbind(labelled)
}

# Coefficient plot: point estimates with confidence-interval error bars, one
# colour per term, dodged side by side within each model on the x axis.
#
# coefs_df:     tidy coefficient table with columns term, estimate, conf.low,
#               conf.high and optionally model (e.g. from combine_tidy_dfs())
# title:        plot title
# legend_title: colour legend title
# x_order:      optional ordering of the models on the x axis (only used when
#               there is more than one model)
# xlab, ylab:   axis labels
# Returns a ggplot object.
make_coefplot <- function(coefs_df,
                          title = "Coefficients Plot",
                          legend_title = "Term",
                          x_order = NULL,
                          xlab = "Dependent Variable",
                          ylab = "Coefficient Estimate") {
  # aes(x = model) below needs the column; without it, plot everything over a
  # single blank x category instead of failing on a missing column
  if (!("model" %in% names(coefs_df))) {
    coefs_df[["model"]] <- rep("", nrow(coefs_df))
  }
  # `[[` rather than `$`, which partially matches names on data frames
  one_model <- length(unique(coefs_df[["model"]])) == 1
  
  # keep terms (colours and dodge order) in order of first appearance
  coefs_df$term <- factor(coefs_df$term,
                          levels = unique(coefs_df$term))
  
  if(!one_model) {
    # models on the x axis: first-appearance order unless x_order is given
    model_levels <- if (is.null(x_order)) {
      unique(coefs_df[["model"]])
    } else {
      x_order
    }
    coefs_df[["model"]] <- factor(coefs_df[["model"]], levels = model_levels)
  } 
  
  ggplot(coefs_df, 
         aes(x = model,
             y = estimate,             
             ymin = conf.low,
             ymax = conf.high)) +
    geom_hline(yintercept = 0, 
               linewidth = 1,
               linetype = "3313") +
    theme_bw() +
    labs(title = title,
         x = xlab,
         y = ylab)  + 
    geom_point(aes(color = term),
               position = position_dodge(width = 0.2)) +
    geom_errorbar(aes(color = term),
                  width = 0.2,
                  position = position_dodge(width = 0.2) ) +
    labs(color = legend_title) +
    scale_color_viridis_d(begin = 0,
                          end = 0.75,
                          option = "magma")
}




# Convenience wrapper: tidy each model with clustered standard errors
# (make_tidys_list), stack the tables (combine_tidy_dfs), keep the requested
# terms and draw a coefficient plot (make_coefplot).
#
# mods_list:   list of fitted models
# names:       optional model labels (NULL keeps the list's own names). It
#              previously had no default, so omitting it raised an error.
# cluster_ids: cluster identifiers passed to make_tidys_list()
# terms:       coefficient names to plot
# ...:         further arguments for make_coefplot() (title, labels, x_order)
# Returns a ggplot object.
mods_list_to_coefplot <- function(mods_list, names = NULL, 
                                  cluster_ids = vendor_end$market,
                                  terms = c("BU", "TD", "Both"),
                                  ...) {
  make_tidys_list(mods_list, names, cluster_ids) |>
    combine_tidy_dfs() |>
    # `.env$` so a data column could never shadow the `terms` argument
    filter(term %in% .env$terms) |> 
    make_coefplot(...)
}


################################################################################

# p-values for a set of linear hypotheses on one model, optionally adjusted
# for multiple testing across those hypotheses.
#
# model:           fitted model passed to linearHypothesis()
# lin_hyps:        vector or list of hypotheses (e.g. "BU = TD"); each
#                  element is tested separately. An element with several
#                  constraints is tested jointly and reported as one row.
# cluster:         NULL to use vcov(model); otherwise cluster ids passed to
#                  cluster.vcov()
# correct:         TRUE to add an adjusted p-value column p_<method>
# p.adjust.method: one of stats::p.adjust.methods (default "holm")
# Returns a data frame with columns model (the caller's expression for
# `model`), lin_hyp, p and, when correct = TRUE, p_<method>.
get_linhyps_ps <- function(model, lin_hyps, cluster = NULL, correct = FALSE,
                           p.adjust.method = p.adjust.methods) {
  
  p.adjust.method <- match.arg(p.adjust.method)
  
  # deparse1() always yields one string, even for long expressions
  mod_name <- deparse1(substitute(model))
  
  vcov_mat <- if (is.null(cluster)) {
    vcov(model)
  } else {
    cluster.vcov(model, cluster)
  }
  
  # F-test for one hypothesis element; multiple constraints are joined into a
  # single label so each test contributes exactly one row (and one p-value)
  f_p <- function(lin_hyp) {
    f_test <- linearHypothesis(
      model = model,
      hypothesis.matrix = lin_hyp,
      vcov. = vcov_mat
    )
    
    data.frame(lin_hyp = paste(lin_hyp, collapse = " & "),
               p = f_test$`Pr(>F)`[2])
  }
  
  p_vals <- map(lin_hyps, .f = f_p) |> 
    list_rbind()
  
  p_vals <- p_vals |> 
    mutate(model = mod_name) |> 
    relocate(model, .before = everything())
  
  if(isTRUE(correct)) {
    pnam <- paste0("p_", p.adjust.method)
    # `!!` injects the string held in `pnam` as the new column's name
    p_vals <- p_vals |>
      mutate(!!pnam := p.adjust(p, p.adjust.method))
  }
  
  p_vals
}

# Linear-hypothesis p-values for several models, with optional multiple-
# testing correction either within each model or across all models.
#
# models:          list of fitted models (list names label the rows; an
#                  unnamed list is labelled by position)
# lin_hyps:        hypotheses tested in every model (see get_linhyps_ps())
# clusters:        NULL for no clustering, otherwise a list with one element
#                  (cluster ids or NULL) per model
# correct:         NULL or FALSE for no correction (the default NULL used to
#                  crash); "within" to adjust within each model; "overall" (or
#                  TRUE, with a warning) to adjust across all rows together
# p.adjust.method: one of stats::p.adjust.methods (default "holm")
# Returns the row-bound get_linhyps_ps() results with `model` set to each
# model's label, plus p_<method> when a correction is requested.
get_linhyps_ps_mods <- function(models,
                                lin_hyps,
                                clusters = NULL,
                                correct = NULL,
                                p.adjust.method = p.adjust.methods) {
  
  p.adjust.method <- match.arg(p.adjust.method)
  
  if(isTRUE(correct)) {
    correct <- "overall"
    warning("`correct` is `TRUE`, defaulting to `'overall'`")
  }
  if (is.null(correct) || isFALSE(correct)) {
    correct <- "none"
  }
  # unknown values now fail here with a clear message instead of later with
  # "object 'overall' not found"
  correct <- match.arg(correct, c("none", "within", "overall"))
  
  within <- correct == "within"
  overall <- correct == "overall"
  
  if (is.null(clusters)) {
    # one NULL per model: a bare NULL has length 0 and pmap() cannot recycle
    # it against the models
    clusters <- rep(list(NULL), length(models))
  }
  
  nams <- if (is.null(names(models))) seq_along(models) else names(models)
  
  get_linhyp_df <- function(model, nam, cluster) {
    get_linhyps_ps(model = model,
                   lin_hyps = lin_hyps,
                   cluster = cluster,
                   correct = within,
                   p.adjust.method = p.adjust.method) |> 
      mutate(model = nam)
  }
  
  mods_p_vals <- pmap(list(model = models, 
                           nam = nams,
                           cluster = clusters),
                      get_linhyp_df) |> 
    list_rbind()
  
  if (overall) {
    pnam <- paste0("p_", p.adjust.method)
    # `!!` injects the string held in `pnam` as the new column's name
    mods_p_vals <- mods_p_vals |>
      mutate(!!pnam := p.adjust(p, p.adjust.method))
  }
  
  mods_p_vals
}

################################################################################

# Mean of each variable in `...` within each level of `control_var`,
# optionally keeping only the level(s) listed in `control_val`.
#
# data:        data frame
# control_var: grouping column (bare name)
# ...:         columns to average (tidyselect); NAs are removed
# control_val: NULL keeps every group; otherwise the group value(s) to keep
# digits:      passed to num(); every numeric column of the result (including
#              a numeric `control_var`) is wrapped with num()
# NOTE(review): `num` is presumably pillar/tibble's num() for digit-controlled
# printing -- confirm it is attached wherever this runs.
get_control_means <- function(data, control_var, ..., control_val = NULL, 
                              digits = 3) {
  
  group_means <- data |> 
    group_by({{ control_var }}) |> 
    summarise(across(c(...), \(col) mean(col, na.rm = TRUE)))
  
  formatted <- group_means |> 
    mutate(across(where(is.numeric), \(col) num(col, digits = digits)))
  
  if (!is.null(control_val)) {
    formatted <- formatted |>
      filter({{ control_var }} %in% control_val)
  }
  
  formatted
}

################################################################################

#function that creates summary table
# Builds a summary-statistics table: per-group means (optionally annotated
# with superscripts marking significant pairwise differences) next to the
# overall mean, SD, min, max and N of each variable, and optionally prints it
# as LaTeX via xtable.
#
# data:              data frame
# group_var:         grouping column (bare name)
# ...:               variables to summarise (bare names)
# cluster:           TRUE uses balance_table_cluster() with `cluster_var`;
#                    FALSE uses balance_table()
# cluster_var:       clustering column (bare name); used when cluster = TRUE
# print_table:       print the xtable LaTeX (to `file`)?
# significance:      round the group means to `digits` and add superscripts
#                    for pairwise p-values <= `alpha`
# digits:            rounding / xtable digits
# alpha:             threshold for the significance superscripts
# caption.placement, size, file: passed to print() of the xtable
# var_names:         optional replacement labels for the Variable column
# title:             table caption
# Returns the summary data frame.
create_sum_stats_tab <- function(data, group_var, ..., cluster = T, 
                                 cluster_var, print_table = T,
                                 significance = T,
                                 digits = 3, alpha = .05, 
                                 caption.placement = "top",
                                 size = "scriptsize",
                                 var_names = NULL,
                                 title = "Summary Statistics Table",
                                 file = getOption("xtable.file", "")){
  
  
  # number of non-missing values (used for the N column)
  num <- function(x){
    sum(!is.na(x))
  }
  
  group_var <- dplyr::enquo(group_var)
  cluster_var <- dplyr::enquo(cluster_var)
  vars <- dplyr::quos(...)
  
  # getting means by cluster and then marking sig differences
  if(cluster){
    bal_tabl <- balance_table_cluster(data, group_var, cluster_var,
                                      ..., passed = T)
  } else {
    bal_tabl <- balance_table(data, group_var, ..., passed = T)
  }
  
  #log vector of p-value columns
  # NOTE(review): matches any column name containing "vs", so a group label
  # containing "vs" would be misclassified as a p-value column
  vs_log <- grepl("vs", names(bal_tabl))
  vs_names <- names(bal_tabl)[vs_log]
  
  #log vector of only group mean columns
  mean_log <- !grepl("vs|Variable", names(bal_tabl))
  mean_names <- names(bal_tabl)[mean_log]
  
  if(significance){
    bal_tabl[mean_log] <- apply(bal_tabl[mean_log], 2, 
                                round, digits = digits)
    
    #adding identifying superscripts to mean names in table
    names(bal_tabl)[mean_log] <- paste(mean_names, "$^",
                                       1:length(mean_names), "$",
                                       sep = "")
    
    #turn means into characters
    # (this happens inside add_sig() below, via paste0)
    
    #figuring out comparison pairs
    # comp_mat: one row per group pair; columns 1-2 are the positions of the
    # two groups within mean_names, column 3 is the bal_tabl column index of
    # that pair's p-value
    comp_mat <- str_split_fixed(vs_names,
                                " vs ",
                                n = 2)
    comp_mat <- apply(comp_mat, c(1,2),
                      function(x){
                        (1:length(mean_names))[mean_names %in% x]
                      })
    comp_mat <- cbind(comp_mat, (1:ncol(bal_tabl))[vs_log])
    
    #adding significance superscripts to variable means
    # x is one row of bal_tabl without its Variable column (see the apply()
    # call below), hence the `- 1` on bal_tabl column indices
    add_sig <- function(x){
      sig_list <- vector(mode = "list", length = length(mean_names))
      
      for(i in 1:nrow(comp_mat)){
        if(is.na(x[comp_mat[i, 3] - 1])) next
        if((x[comp_mat[i, 3] - 1]) <= alpha){
          sig_list[[comp_mat[i, 1]]] <- c(sig_list[[comp_mat[i, 1]]],
                                          comp_mat[i, 2])
          sig_list[[comp_mat[i, 2]]] <- c(sig_list[[comp_mat[i, 2]]],
                                          comp_mat[i, 1])          
        }
      }
      
      # one superscript "$^k$" or a list "$^{k,l}$" per group mean
      sig_list <- lapply(sig_list, function(x){
        if(length(x) == 0)
          return(x)
        else if(length(x) == 1)
          return(paste0("$^", x, "$"))
        else if(length(x) > 1){
          return(paste0("$^{", paste(x, collapse = ","), "}$"))
        }
      })
      
      for(j in 1:length(sig_list)){
        x[j] <- paste0(x[j], sig_list[[j]])
      }
      x[1:length(mean_names)]
    }
    
    bal_tabl[mean_log] <- t(apply(bal_tabl[2:ncol(bal_tabl)],
                                  1,
                                  add_sig))
    
    bal_tabl <- bal_tabl[!vs_log]
    
  } else{
    bal_tabl <- bal_tabl[!vs_log]
  }  
  
  #getting other summary statistics
  sum_stats <- purrr::map_dfr(vars, function(x){
    data %>% 
      select(!!x) %>% 
      summarise_all(list(~mean(., na.rm = T),
                         ~sd(., na.rm = T),
                         ~min(., na.rm = T),
                         ~max(., na.rm = T),
                         ~num(.)))
  })
  names(sum_stats) <- c("Overall Mean", "SD", "Min", "Max", "N")
  
  # final column order: Variable, Overall Mean, group means, SD, Min, Max, N
  sum_stats <- cbind(bal_tabl, sum_stats) %>% 
    select(Variable, `Overall Mean`, everything())
  
  if(!is.null(var_names)) sum_stats$Variable <- var_names
  
  
  if(print_table){
    if(significance){
      title <- paste0(title, ". Superscripts in column names identify groups. Superscripts in cells indicate that a value is significantly different from the value for the superscripted group")
    }
    
    # one entry per column plus a leading one for the (unprinted) row names:
    # row names, Variable, |Overall Mean, group means, |SD, Min, Max, N
    alignment <- c("l", "l", "|l", rep("l", length(mean_names)), "|l",
                   rep("l", 3))
    
    print(xtable::xtable(sum_stats, header = F,
                         digits = digits,
                         caption = title,
                         align = alignment),
          comment = F, include.rownames = F,
          sanitize.text.function = sanitize_allow_latex,
          caption.placement = caption.placement,
          size = size,
          file = file)
    
  }
  
  sum_stats
  
}

################################################################################

# Variant of xtable's sanitize function that leaves `$`, `^`, `{` and `}`
# untouched, so inline math such as "$^1$" passes through to LaTeX.
#
# str:  character vector to escape
# type: "latex" escapes LaTeX specials (backslash, > < | % & _ # ~);
#       anything else escapes & > < as HTML entities
# Returns the escaped character vector.
sanitize_allow_exp <- function (str, type = "latex") 
{
  if (type != "latex") {
    html_escapes <- c("&" = "&amp;", ">" = "&gt;", "<" = "&lt;")
    for (ch in names(html_escapes)) {
      str <- gsub(ch, html_escapes[[ch]], str, fixed = TRUE)
    }
    return(str)
  }
  
  # protect backslashes first so the escapes added below are not re-escaped
  out <- gsub("\\\\", "SANITIZE.BACKSLASH", str)
  # applied in order; the placeholder is resolved last
  latex_escapes <- c(">" = "$>$",
                     "<" = "$<$",
                     "|" = "$|$",
                     "%" = "\\%",
                     "&" = "\\&",
                     "_" = "\\_",
                     "#" = "\\#",
                     "~" = "\\~{}",
                     "SANITIZE.BACKSLASH" = "$\\backslash$")
  for (pat in names(latex_escapes)) {
    out <- gsub(pat, latex_escapes[[pat]], out, fixed = TRUE)
  }
  out
}

################################################################################

# Variant of xtable's sanitize function that escapes LaTeX specials but keeps
# `$`, `^`, `{`, `}` and a whitelist of LaTeX commands (\footnote, \textbf,
# \protect, \makecell and the `\\` line break) intact.
#
# str:  character vector to escape
# type: "latex" for LaTeX escaping; anything else escapes & > < as HTML
# Returns the escaped character vector.
sanitize_allow_latex <- function (str, type = "latex") 
{
  if (type == "latex") {
    result <- str
    # protect backslashes first so the escapes added below are not re-escaped
    result <- gsub("\\\\", "SANITIZE.BACKSLASH", 
                   result)
    result <- gsub(">", "$>$", result, fixed = TRUE)
    result <- gsub("<", "$<$", result, fixed = TRUE)
    result <- gsub("|", "$|$", result, fixed = TRUE)
    result <- gsub("%", "\\%", result, fixed = TRUE)
    result <- gsub("&", "\\&", result, fixed = TRUE)
    result <- gsub("_", "\\_", result, fixed = TRUE)
    result <- gsub("#", "\\#", result, fixed = TRUE)
    result <- gsub("~", "\\~{}", result, fixed = TRUE)
    result <- gsub("SANITIZE.BACKSLASH", "$\\backslash$", 
                   result, fixed = TRUE)
    # Undo the escaping for whitelisted commands, in this order. `\\` is
    # restored before `\makecell`, so an input "\\makecell" stays a line break
    # followed by plain text.
    allowed <- c("$\\backslash$footnote"      = "\\footnote",
                 "$\\backslash$textbf"        = "\\textbf",
                 "$\\backslash$protect"       = "\\protect",
                 "$\\backslash$$\\backslash$" = "\\\\",
                 "$\\backslash$makecell"      = "\\makecell")
    for (i in seq_along(allowed)) {
      result <- gsub(names(allowed)[i], allowed[[i]], result, fixed = TRUE)
    }
    # (the old trailing `result <-` whose value was `return(result)` is gone)
    return(result)
  }
  else {
    result <- str
    result <- gsub("&", "&amp;", result, fixed = TRUE)
    result <- gsub(">", "&gt;", result, fixed = TRUE)
    result <- gsub("<", "&lt;", result, fixed = TRUE)
    return(result)
  }
}

################################################################################

# No-op sanitizer for xtable: returns `str` unchanged. `type` exists only so
# the signature matches xtable's sanitize functions; it is ignored.
sanitize_passthrough <- function (str, type = "latex") 
{
  identity(str)
}

################################################################################

# define xtable alignments: row-name column, one 0.4\linewidth paragraph
# column, then four 0.1\linewidth paragraph columns

xtabl_tg_align <- c("l", "p{0.4\\linewidth}", rep("p{0.1\\linewidth}", 4))

################################################################################

# Two-sided p-values for selected coefficients, using cluster-robust standard
# errors (calc.ses.cluster) and a standard-normal reference distribution.
#
# model:        fitted model; estimates are read via tidy()
# test_vars:    coefficient names to test (default: all)
# cluster_var:  name of the clustering column in `cluster_data`
# cluster_data: data frame holding `cluster_var`
# Returns a numeric vector of p-values in the row order of tidy(model),
# not necessarily the order of `test_vars`.
calc_p_cluster <- function(model, test_vars = names(coef(model)), 
                           cluster_var = "market", cluster_data = vendor_end){
  
  estimates <- tidy(model) %>% 
    filter(term %in% test_vars) %>% 
    pull(estimate)
  
  # clustered SEs keyed by coefficient name (kept as row names)
  se_tbl <- rownames_to_column(
    data.frame(se = calc.ses.cluster(model, cluster_var, data = cluster_data))
  )
  clustered_se <- pull(filter(se_tbl, rowname %in% test_vars), se)
  
  z_stat <- estimates / clustered_se
  2 * pnorm(abs(z_stat), mean = 0, sd = 1, 
            lower.tail = FALSE, log.p = FALSE)
}

# function to format and create latex for linear hypothesis tables
# version with models as rows
#
# lh:            data frame with columns model, lin_hyp, p and optionally
#                further p-value columns whose names start with "p"
#                (e.g. p_holm from get_linhyps_ps_mods())
# format_cnames: relabel p-value columns; pivot_wider() names them
#                "<p column>_<hypothesis>" when there are several p columns,
#                which this turns into "<hypothesis> (p)" or
#                "<hypothesis> (p-adj: <method>)"
# adj_only:      drop the unadjusted `p` column before reshaping
# mod_names:     optional replacement labels for the model column
# caption, label, align, digits: passed to xtable(); `label` defaults to
#                "temp:lab", `align` to two left-aligned entries (row names
#                and Outcome) followed by centred value columns
# ...:           passed to print.xtable()
# Returns the result of print.xtable().
# NOTE(review): hypotheses containing "_" will be split incorrectly by the
# column relabelling; with a single p column pivot_wider() uses the bare
# hypothesis as the name, so no label is added.
print_linhyps_table_modrows <- function(lh, 
                                        format_cnames = TRUE, 
                                        adj_only = FALSE,
                                        mod_names = NULL,
                                        caption = NULL,
                                        label = NULL, 
                                        align = NULL,
                                        digits = 3,
                                        ...) {
  
  if(isTRUE(adj_only)) {
    lh$p <- NULL
  }
  
  # one row per model, one column per (p column, hypothesis) combination
  lh <- lh |> 
    pivot_wider(
      id_cols = model,
      names_from = lin_hyp,
      values_from = starts_with("p")
    )
  
  if(isTRUE(format_cnames)) {
    
    cnames <- names(lh)
    p_cols <- stringr::str_detect(cnames, "p_")
    
    # "p_holm_BU = TD" -> "BU = TD (p-adj: holm)"; "p_BU = TD" -> "BU = TD (p)"
    format_cname_func <- function(cname) {
      cname_frmtd <- stringr::str_split_1(cname, "_")
      end <- length(cname_frmtd)
      cname_frmtd <- glue::glue(
        "{cname_frmtd[end]}",
        " ",
        "({stringr::str_c(cname_frmtd[-end], collapse = '-adj: ')})")
    }
    
    cnames[p_cols] <- map_chr(cnames[p_cols],
                              format_cname_func)
    
    names(lh)[p_cols] <- cnames[p_cols]
    
  }
  
  if(!is.null(mod_names)) {
    lh <- lh |> 
      mutate(model = mod_names)
  }
  
  
  # xtable needs one entry per column plus one for the row names
  if(is.null(align)) {
    align <- c("l", "l", rep("c", times = (ncol(lh) - 1) ))
  }
  
  if(is.null(label)) {
    label <- "temp:lab"
  }
  
  lh <- lh |> 
    rename(Outcome = model)
  
  lh |> 
    xtable::xtable(caption = caption,
                   label = label,
                   align = align,
                   digits = digits) |> 
    xtable::print.xtable(comment = FALSE, ...)
  
}

# function to format and create latex for linear hypothesis tables,
# version with linear hypotheses as rows and models as columns
#
# lhs:       data frame with columns model, lin_hyp, p and optionally further
#            p-value columns whose names start with "p" (e.g. p_holm)
# adj_only:  keep only the adjusted p-value rows (rows then show just the
#            hypothesis text)
# mod_names: optional replacement labels for the model columns, in column
#            order
# caption, label, align, digits: passed to xtable(); `label` defaults to
#            "temp:lab", `align` to two left-aligned entries (row names and
#            hypothesis) followed by centred value columns
# ...:       passed to print.xtable()
# Returns the result of print.xtable().
print_linhyps_table <- function(lhs, 
                                #format_cnames = TRUE, 
                                adj_only = FALSE,
                                mod_names = NULL,
                                caption = NULL,
                                label = NULL, 
                                align = NULL,
                                digits = 3,
                                ...) {
  
  # long: one row per (hypothesis, p type, model); then wide: one row per
  # (hypothesis, p type), one column per model
  lh <- lhs |> 
    pivot_longer(
      cols = starts_with("p"),
      names_to = "p_type",
      values_to = "p"
    ) |> 
    pivot_wider(
      id_cols = c(lin_hyp, p_type),
      names_from = model,
      values_from = p
    )
  
  if(isTRUE(adj_only)) {
    lh <- lh |> 
      filter(p_type != "p") |> 
      mutate(p_type = NULL)
  } else {
    # unadjusted rows ("p" has no second "_" piece, so p_type is NA) keep the
    # hypothesis text; adjusted rows ("p_<method>") are relabelled
    # "\textit{$p$-adj: <Method>}"
    lh <- lh |> 
      mutate(p_type = stringr::str_split_i(p_type, "_", 2),
             p_type = stringr::str_to_title(p_type),
             lin_hyp = if_else(is.na(p_type),
                               lin_hyp,
                               stringr::str_c("\\textit{",
                                              "$p$-adj: ",
                                              p_type,
                                              "}")),
             p_type = NULL
      )
  }
  
  
  
  
  if(!is.null(mod_names)) {
    
    names(lh)[2:ncol(lh)] <- mod_names
    
  }
  
  
  # xtable needs one entry per column plus one for the row names
  if(is.null(align)) {
    align <- c("l", "l", rep("c", times = (ncol(lh) - 1) ))
  }
  
  if(is.null(label)) {
    label <- "temp:lab"
  }
  
  lh <- lh |> 
    rename(`Linear Hypothesis` = lin_hyp)
  
  lh |> 
    xtable::xtable(caption = caption,
                   label = label,
                   align = align,
                   digits = digits) |> 
    xtable::print.xtable(comment = FALSE, ...)
  
}

################################################################################

# Conventional (non-clustered) p-values of selected coefficients, taken from
# tidy(model) and returned in its row order.
#
# model:     fitted model
# test_vars: coefficient names to keep (default: all)
get_p <- function(model, test_vars = names(coef(model))){
  coef_table <- tidy(model)
  wanted <- coef_table$term %in% test_vars
  coef_table$p.value[wanted]
}

# Zero-row data frame with the columns of an MHTC table
# (Level, Outcome, Term, Hypothesis as character; p as numeric).
make_empty_mhtc_df <- function(){
  empty_cols <- list(Level = character(),
                     Outcome = character(),
                     Term = character(),
                     Hypothesis = character(),
                     p = numeric())
  do.call(data.frame, empty_cols)
}

# Multiple-hypothesis-testing correction (MHTC) table for a family of
# treatment coefficients from individual-level and/or market-level models.
#
# ind_mods:     list of individual-level models; p-values from
#               calc_p_cluster() (cluster-robust SEs)
# mkt_mods:     list of market-level models; p-values from get_p()
# hypothesis:   label stored in the Hypothesis column
# test_vars:    coefficient names collected from each model
# outcomes_ind, outcomes_mkt: outcome labels, one per model, in list order
# n_corr:       number of tests to correct for (p.adjust()'s `n`). NULL uses
#               the number of non-missing p-values, i.e. p.adjust()'s own
#               default. A supplied value was previously ignored.
# cluster_var, cluster_data: passed to calc_p_cluster()
# alpha:        significance threshold for the survives_* columns
#
# Returns a data frame with Level, Outcome, Term, Hypothesis, p, p_holm,
# p_BH, survives_holm and survives_BH ("Yes"/"No" for coefficients with
# p <= alpha, NA otherwise).
#
# NOTE(review): p-values are matched to Outcome x Term labels by position, so
# each model must return exactly one p-value per `test_vars` entry in that
# order, and `hypothesis` should have length 1 -- confirm.
make_mhtc_table <- function(ind_mods = NULL, 
                            mkt_mods = NULL,
                            hypothesis,
                            test_vars = c("BU", "TD", "Both"),
                            outcomes_ind,
                            outcomes_mkt = outcomes_ind,
                            n_corr = NULL,
                            cluster_var = "market",
                            cluster_data = vendor_end,
                            alpha = 0.05
                            
) {
  
  # getting p-values for ind-level models
  if(!is.null(ind_mods)){
    ps_ind <- unlist(lapply(ind_mods, calc_p_cluster, 
                            test_vars = test_vars,
                            cluster_var = cluster_var,
                            cluster_data = cluster_data))
    
    ps_corr_ind <- expand_grid(Level = "Individual",
                               Outcome = outcomes_ind,
                               Term = test_vars,
                               Hypothesis = hypothesis) %>% 
      mutate(p = ps_ind)
    
  } else {
    ps_corr_ind <- make_empty_mhtc_df()
  }
  
  # getting p-values for market-level models
  if(!is.null(mkt_mods)){
    ps_mkt <- unlist(lapply(mkt_mods, get_p, test_vars = test_vars))
    
    ps_corr_mkt <- expand_grid(Level = "Market",
                               Outcome = outcomes_mkt,
                               Term = test_vars,
                               Hypothesis = hypothesis) %>% 
      mutate(p = ps_mkt)
    
  } else {
    ps_corr_mkt <- make_empty_mhtc_df()
  }
  
  # combine
  ps_corr <- bind_rows(ps_corr_ind, ps_corr_mkt)
  
  # size of the testing family
  if(is.null(n_corr)){
    n_corr <- sum(!is.na(ps_corr$p))
  }
  
  # "Yes"/"No" only for nominally significant coefficients; NA otherwise
  survives <- function(p, p_adj) {
    case_when(
      p_adj <= alpha & p <= alpha ~ "Yes",
      p_adj > alpha & p <= alpha ~ "No"
    )
  }
  
  ps_corr %>% 
    mutate(p_holm = p.adjust(p, "holm", n = n_corr),
           p_BH = p.adjust(p, "BH", n = n_corr),
           survives_holm = survives(p, p_holm),
           survives_BH = survives(p, p_BH))
}